import torch
from torch.utils.data import Dataset, DataLoader
import numpy as np
import scipy.sparse
import scipy.io
from . import file_utils
from typing import List, Tuple, Union, Mapping, Any, Callable, Iterable
from sentence_transformers import SentenceTransformer
import os


def load_contextual_embed(
    texts, device="cuda", model_name="all-mpnet-base-v2", show_progress_bar=True
):
    """Encode *texts* into dense document embeddings with sentence-transformers.

    Loads `model_name` onto `device`, encodes the whole list of texts, prints
    the resulting array's shape, and returns it (a numpy array of shape
    (num_texts, embed_dim)).
    """
    encoder = SentenceTransformer(model_name, device=device)
    doc_embeddings = encoder.encode(texts, show_progress_bar=show_progress_bar)
    print(doc_embeddings.shape)
    return doc_embeddings


class _SequentialDataset(Dataset):
    def __init__(
        self,
        bow,
        contextual_embed,
        times,
        time_wordfreq,
    ):
        super().__init__()
        self.bow = bow
        self.times = times
        self.time_wordfreq = time_wordfreq
        self.S_Emb = contextual_embed

    def __len__(self):
        return len(self.bow)

    def __getitem__(self, index):
        if isinstance(self.S_Emb[index], torch.Tensor):
            S_Emb = self.S_Emb[index]
        else:
            S_Emb = torch.FloatTensor(self.S_Emb[index])

        return_dict = {
            "bow": self.bow[index],
            "s_emb": S_Emb,
            "times": self.times[index],
            "time_wordfreq": self.time_wordfreq[self.times[index]],
        }

        return return_dict


class DynamicDataset:
    """Load a dynamic (time-sliced) topic-model dataset from ``dataset_dir``.

    Expects the directory to contain sparse train/test BoW matrices
    (``train_bow.npz``/``test_bow.npz``), raw texts, per-document time-slice
    files, a vocabulary file, and pretrained word embeddings. Optionally
    computes (or loads cached) contextual document embeddings and wraps
    everything in PyTorch ``DataLoader``s.

    Parameters
    ----------
    dataset_dir : str
        Directory holding the dataset files read by :meth:`load_data`.
    batch_size : int
        Batch size for both train and test dataloaders.
    read_labels : bool
        Also read ``train_labels.txt`` / ``test_labels.txt`` when True.
    device : str
        Device the tensors are moved to when ``as_tensor`` is True.
    as_tensor : bool
        Convert numpy arrays to torch tensors and build datasets/dataloaders.
    contextual_embed : bool
        Compute (or load cached) sentence-transformer document embeddings.
    doc_embed_model : str
        sentence-transformers model name used when embeddings must be computed.
    """

    def __init__(
        self,
        dataset_dir,
        batch_size=200,
        read_labels=False,
        device="cpu",
        as_tensor=True,
        contextual_embed=False,
        doc_embed_model="all-MiniLM-L6-v2",
    ):
        self.load_data(dataset_dir, read_labels)

        self.vocab_size = len(self.vocab)
        self.train_size = len(self.train_bow)
        # NOTE(review): assumes time slices are contiguous integers 0..T-1;
        # num_times counts distinct values, and downstream indexing relies on
        # that range — confirm upstream preprocessing guarantees this.
        self.num_times = len(np.unique(self.train_times))
        self.train_time_wordfreq = self.get_time_wordfreq(
            self.train_bow, self.train_times
        )

        print("train size: ", len(self.train_bow))
        print("test size: ", len(self.test_bow))
        print("vocab size: ", len(self.vocab))
        print("average length: {:.3f}".format(self.train_bow.sum(1).mean().item()))
        print("num of each time slice: ", self.num_times, np.bincount(self.train_times))

        if contextual_embed:
            self._load_or_build_doc_embeddings(dataset_dir, doc_embed_model)
            self.contextual_embed_size = self.train_contextual_embed.shape[1]
        else:
            # BUGFIX: the tensor conversion and Dataset construction below
            # always touch the contextual embeddings, so the original code
            # raised AttributeError with the DEFAULT arguments
            # (contextual_embed=False, as_tensor=True). Zero-width
            # placeholders keep every downstream step working unchanged.
            self.train_contextual_embed = np.zeros(
                (self.train_size, 0), dtype="float32"
            )
            self.test_contextual_embed = np.zeros(
                (len(self.test_bow), 0), dtype="float32"
            )

        if as_tensor:
            self.train_bow = torch.from_numpy(self.train_bow).float().to(device)
            self.test_bow = torch.from_numpy(self.test_bow).float().to(device)
            self.train_times = torch.from_numpy(self.train_times).long().to(device)
            self.test_times = torch.from_numpy(self.test_times).long().to(device)
            self.train_time_wordfreq = (
                torch.from_numpy(self.train_time_wordfreq).float().to(device)
            )
            # np.asarray is a no-op for arrays already materialized; it guards
            # against encode() returning a non-contiguous result.
            self.train_contextual_embed = (
                torch.from_numpy(np.asarray(self.train_contextual_embed))
                .float()
                .to(device)
            )
            self.test_contextual_embed = (
                torch.from_numpy(np.asarray(self.test_contextual_embed))
                .float()
                .to(device)
            )
            self.train_dataset = _SequentialDataset(
                self.train_bow,
                self.train_contextual_embed,
                self.train_times,
                self.train_time_wordfreq,
            )
            # NOTE(review): the test split reuses the TRAIN per-slice word
            # frequencies — presumably intentional (test statistics are never
            # computed), but verify against the model's expectations.
            self.test_dataset = _SequentialDataset(
                self.test_bow,
                self.test_contextual_embed,
                self.test_times,
                self.train_time_wordfreq,
            )

            self.train_dataloader = DataLoader(
                self.train_dataset, batch_size=batch_size, shuffle=True
            )
            self.test_dataloader = DataLoader(
                self.test_dataset, batch_size=batch_size, shuffle=False
            )

    def _load_or_build_doc_embeddings(self, dataset_dir, doc_embed_model):
        """Load cached contextual document embeddings, or compute and cache them."""
        cache_dir = os.path.join(dataset_dir, "with_bert")
        train_cache = os.path.join(cache_dir, "train_bert.npy")
        if os.path.isfile(train_cache):
            self.train_contextual_embed = np.load(train_cache).astype("float32")
            self.test_contextual_embed = np.load(
                os.path.join(cache_dir, "test_bert.npy")
            ).astype("float32")
        else:
            self.train_contextual_embed = load_contextual_embed(
                self.train_texts, model_name=doc_embed_model
            )
            self.test_contextual_embed = load_contextual_embed(
                self.test_texts, model_name=doc_embed_model
            )
            # Cache to disk so later runs skip the (slow) encoding step.
            os.makedirs(cache_dir, exist_ok=True)
            np.save(
                os.path.join(cache_dir, "train_bert"),
                self.train_contextual_embed,
            )
            np.save(
                os.path.join(cache_dir, "test_bert"),
                self.test_contextual_embed,
            )

    def load_data(self, path, read_labels):
        """Read BoW matrices, texts, time slices, vocab and embeddings from *path*.

        Populates instance attributes; raises the underlying I/O error if any
        expected file is missing.
        """
        self.train_bow = (
            scipy.sparse.load_npz(f"{path}/train_bow.npz").toarray().astype("float32")
        )
        self.test_bow = (
            scipy.sparse.load_npz(f"{path}/test_bow.npz").toarray().astype("float32")
        )
        # Load the embedding matrix ONCE — the original re-read the same file
        # a second time for pretrained_WE. A copy keeps the two attributes
        # independent, as they were before.
        self.word_embeddings = (
            scipy.sparse.load_npz(f"{path}/word_embeddings.npz")
            .toarray()
            .astype("float32")
        )
        self.pretrained_WE = self.word_embeddings.copy()

        self.train_texts = file_utils.read_text(f"{path}/train_texts.txt")
        self.test_texts = file_utils.read_text(f"{path}/test_texts.txt")

        self.train_times = np.loadtxt(f"{path}/train_times.txt").astype("int32")
        self.test_times = np.loadtxt(f"{path}/test_times.txt").astype("int32")

        self.vocab = file_utils.read_text(f"{path}/vocab.txt")

        if read_labels:
            self.train_labels = np.loadtxt(f"{path}/train_labels.txt").astype("int32")
            self.test_labels = np.loadtxt(f"{path}/test_labels.txt").astype("int32")

    def get_time_wordfreq(self, bow, times):
        """Return a (num_times, vocab_size) matrix of per-slice mean word counts.

        Row ``t`` is the sum of the BoW vectors of the documents in slice ``t``
        divided by the number of such documents.
        """
        time_wordfreq = np.zeros((self.num_times, self.vocab_size))
        for time in range(self.num_times):
            idx = np.where(times == time)[0]
            time_wordfreq[time] += bow[idx].sum(0)
        # BUGFIX: minlength + truncation keeps the divisor aligned with
        # num_times even when slice ids are not perfectly contiguous, and the
        # np.maximum guard avoids a divide-by-zero (NaN row) for empty slices.
        cnt_times = np.bincount(times, minlength=self.num_times)[: self.num_times]
        time_wordfreq = time_wordfreq / np.maximum(cnt_times, 1)[:, np.newaxis]
        return time_wordfreq
